{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import openpifpaf\n",
    "import openpifpaf.logs\n",
    "openpifpaf.show.Canvas.show = True\n",
    "\n",
    "import torch\n",
    "torch.ops.openpifpaf.set_quiet(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cifar10\n",
    "\n",
    "This page gives a quick introduction to OpenPifPaf's Cifar10 plugin, which is part of `openpifpaf.plugins`.\n",
    "It demonstrates the plugin architecture.\n",
    "`torchvision` already provides a dataset for CIFAR10, and there is a related [PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).\n",
    "The plugin adds a `DataModule` that uses this dataset.\n",
    "Let's start with the setup for this notebook and list all registered OpenPifPaf plugins:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(openpifpaf.plugin.REGISTERED.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we configure and instantiate the Cifar10 datamodule and look at the configured head metas:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configure \n",
    "openpifpaf.plugins.cifar10.datamodule.Cifar10.debug = True \n",
    "openpifpaf.plugins.cifar10.datamodule.Cifar10.batch_size = 1\n",
    "\n",
    "# instantiate and inspect\n",
    "datamodule = openpifpaf.plugins.cifar10.datamodule.Cifar10()\n",
    "datamodule.set_loader_workers(0)  # no multi-processing to see debug outputs in main thread\n",
    "datamodule.head_metas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see here that CIFAR10 is being treated as a detection dataset (`CifDet`) and has 10 categories.\n",
    "To create a network, we use the `factory()` function that takes the name of the base network `cifar10net` and the list of head metas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# factory() returns a tuple; its first entry is the network\n",
    "net, _ = openpifpaf.network.Factory(base_name='cifar10net').factory(head_metas=datamodule.head_metas)"
   ]
  },
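  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for treating a classification dataset as a detection dataset, the next cell shows a minimal sketch. It is not the plugin's code.\n",
    "It assumes that every image contains a single object whose bounding box spans the full 32x32 image.\n",
    "The function `label_to_detection_annotation`, the annotation keys and the box coordinates are hypothetical; the category names follow the PyTorch tutorial linked above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical illustration, not the implementation of the Cifar10 plugin.\n",
    "# Category names follow the PyTorch CIFAR10 tutorial.\n",
    "CATEGORIES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "\n",
    "def label_to_detection_annotation(label, image_size=32):\n",
    "    # wrap a class index (0 to 9) in a detection-style annotation\n",
    "    if not 0 <= label < len(CATEGORIES):\n",
    "        raise ValueError('label must be in [0, {})'.format(len(CATEGORIES)))\n",
    "    return {\n",
    "        'category': CATEGORIES[label],\n",
    "        'category_index': label,\n",
    "        # assumed box in (x, y, width, height) format that spans the full image\n",
    "        'bbox': (0.0, 0.0, float(image_size), float(image_size)),\n",
    "    }\n",
    "\n",
    "\n",
    "print(label_to_detection_annotation(9))"
   ]
  },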
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the training data that is returned from `datamodule.train_loader()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configure visualization\n",
    "openpifpaf.visualizer.Base.set_all_indices(['cifdet:9:regression'])  # category 9 = truck\n",
    "\n",
    "# Create a wrapper for a data loader that iterates over a set of matplotlib axes.\n",
    "# The only purpose is to set a different matplotlib axis before each call to \n",
    "# retrieve the next image from the data_loader so that it produces multiple\n",
    "# debug images in one canvas side-by-side.\n",
    "def loop_over_axes(axes, data_loader):\n",
    "    previous_common_ax = openpifpaf.visualizer.Base.common_ax\n",
    "    train_loader_iter = iter(data_loader)\n",
    "    for ax in axes.reshape(-1):\n",
    "        openpifpaf.visualizer.Base.common_ax = ax\n",
    "        yield next(train_loader_iter, None)\n",
    "    openpifpaf.visualizer.Base.common_ax = previous_common_ax\n",
    "\n",
    "# create a canvas and loop over the first few entries in the training data\n",
    "with openpifpaf.show.canvas(ncols=6, nrows=3, figsize=(10, 5)) as axs:\n",
    "    for images, targets, meta in loop_over_axes(axs, datamodule.train_loader()):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "We train a very small network, `cifar10net`, for only three epochs. Afterwards, we will investigate its predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "python -m openpifpaf.train \\\n",
    "  --dataset=cifar10 --basenet=cifar10net --log-interval=50 \\\n",
    "  --epochs=3 --lr=0.0003 --momentum=0.95 --batch-size=16 \\\n",
    "  --lr-warm-up-epochs=0.1 --lr-decay 2.0 2.5 --lr-decay-epochs=0.1 \\\n",
    "  --loader-workers=2 --output=cifar10_tutorial.pkl"
   ]
  },
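  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The flags `--lr-warm-up-epochs=0.1`, `--lr-decay 2.0 2.5` and `--lr-decay-epochs=0.1` describe a learning rate schedule: a short warm-up followed by two decay steps that start at epochs 2.0 and 2.5.\n",
    "The next cell is a rough sketch of such a schedule and not OpenPifPaf's scheduler code.\n",
    "It assumes that the warm-up starts at 0.001 times the base learning rate, that every decay step multiplies the learning rate by 0.1, and that both transitions are interpolated exponentially. The function `sketch_lr` and these two factors are assumptions for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch under stated assumptions, not the scheduler used by openpifpaf.train.\n",
    "def sketch_lr(epoch, base_lr=0.0003, warm_up_epochs=0.1,\n",
    "              decay_schedule=(2.0, 2.5), decay_epochs=0.1,\n",
    "              warm_up_factor=0.001, decay_factor=0.1):\n",
    "    # approximate learning rate at a (fractional) epoch\n",
    "    scale = 1.0\n",
    "    if epoch < warm_up_epochs:\n",
    "        # ramp up exponentially from warm_up_factor to 1.0\n",
    "        scale *= warm_up_factor ** (1.0 - epoch / warm_up_epochs)\n",
    "    for decay_start in decay_schedule:\n",
    "        if epoch >= decay_start + decay_epochs:\n",
    "            # this decay step is complete\n",
    "            scale *= decay_factor\n",
    "        elif epoch > decay_start:\n",
    "            # smooth transition inside the decay window\n",
    "            scale *= decay_factor ** ((epoch - decay_start) / decay_epochs)\n",
    "    return base_lr * scale\n",
    "\n",
    "\n",
    "for epoch in (0.0, 0.05, 1.0, 2.05, 2.2, 3.0):\n",
    "    print('epoch {:.2f}: lr = {:.2e}'.format(epoch, sketch_lr(epoch)))"
   ]
  },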
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Training Logs\n",
    "\n",
    "You can create a set of plots from the command line with `python -m openpifpaf.logs cifar10_tutorial.pkl.log`. You can also overlay multiple runs. Below we call the plotting code from that command directly to show the output in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "openpifpaf.logs.Plots(['cifar10_tutorial.pkl.log']).show_all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "\n",
    "First, using the command line interface:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "python -m openpifpaf.predict --checkpoint cifar10_tutorial.pkl.epoch003 images/cifar10_*.png --seed-threshold=0.1 --json-output . --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash \n",
    "cat cifar10_*.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_cpu, _ = openpifpaf.network.Factory(checkpoint='cifar10_tutorial.pkl.epoch003').factory()\n",
    "preprocess = openpifpaf.transforms.Compose([\n",
    "    openpifpaf.transforms.NormalizeAnnotations(),\n",
    "    openpifpaf.transforms.CenterPadTight(16),\n",
    "    openpifpaf.transforms.EVAL_TRANSFORM,\n",
    "])\n",
    "\n",
    "openpifpaf.decoder.utils.CifDetSeeds.set_threshold(0.3)\n",
    "decode = openpifpaf.decoder.factory([hn.meta for hn in net_cpu.head_nets])\n",
    "\n",
    "data = openpifpaf.datasets.ImageList([\n",
    "    'images/cifar10_airplane4.png',\n",
    "    'images/cifar10_automobile10.png',\n",
    "    'images/cifar10_ship7.png',\n",
    "    'images/cifar10_truck8.png',\n",
    "], preprocess=preprocess)\n",
    "for image, _, meta in data:\n",
    "    predictions = decode.batch(net_cpu, image.unsqueeze(0))[0]\n",
    "    print(['{} {:.0%}'.format(pred.category, pred.score) for pred in predictions])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "I selected the images above because their category is clear to me. CIFAR10 also contains images whose category is harder to tell, and those are probably harder for a neural network as well.\n",
    "\n",
    "Therefore, we should run a proper quantitative evaluation with `openpifpaf.eval`. It stores its output in a JSON file, which we print afterwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "python -m openpifpaf.eval --checkpoint cifar10_tutorial.pkl.epoch003 --dataset=cifar10 --seed-threshold=0.1 --instance-threshold=0.1 --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "python -m json.tool cifar10_tutorial.pkl.epoch003.eval-cifar10.stats.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that some categories like \"plane\", \"car\" and \"ship\" are learned quickly, whereas others are learned poorly (e.g. \"bird\"). The poor performance is not surprising as we trained our network for only a few epochs."
   ]
  },
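  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make per-category numbers like these concrete, the next cell sketches how a per-category accuracy could be computed from pairs of true and predicted labels.\n",
    "It is an illustration with made-up toy labels and not the code behind `openpifpaf.eval`. The function name `per_category_accuracy` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "def per_category_accuracy(true_labels, predicted_labels):\n",
    "    # map every category to the fraction of its samples that were predicted correctly\n",
    "    totals = Counter(true_labels)\n",
    "    correct = Counter(t for t, p in zip(true_labels, predicted_labels) if t == p)\n",
    "    return {category: correct[category] / n for category, n in totals.items()}\n",
    "\n",
    "\n",
    "# toy labels for illustration only, not results from the trained network\n",
    "toy_true = ['plane', 'plane', 'bird', 'bird', 'ship']\n",
    "toy_pred = ['plane', 'plane', 'cat', 'bird', 'ship']\n",
    "print(per_category_accuracy(toy_true, toy_pred))"
   ]
  },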
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "ea6946363a43e80d241452ab397f4c58bdd3d2517da174158e9c46ce6717422a"
  },
  "kernelspec": {
   "display_name": "Python 3.9.6 64-bit ('venv3': venv)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "metadata": {
   "interpreter": {
    "hash": "ea6946363a43e80d241452ab397f4c58bdd3d2517da174158e9c46ce6717422a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
